package com.ruoyi.common.interceptor;

import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.ruoyi.common.annotation.DataTenantDataScope;
import com.ruoyi.common.core.domain.model.LoginUser;
import com.ruoyi.common.utils.SecurityUtils;
import com.ruoyi.common.utils.StringUtils;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.HexValue;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.List;
import java.util.Objects;
/**
 * MyBatis-Plus inner interceptor that adds a tenant (company) restriction to SELECT statements.
 *
 * <p>When {@code tenantFlag} is enabled, a query issued through a mapper method annotated with
 * {@link DataTenantDataScope} gets {@code <main table or alias>.company_id = '<companyId>'}
 * ANDed into the WHERE clause of its top-level query, or of each branch of a set operation such
 * as UNION. The company id comes from the current {@link LoginUser}. Queries from non-annotated
 * methods, or from users without a company id, are left unchanged.
 *
 * @author zxg
 * @date 2023-11-06
 */
@Slf4j
@AllArgsConstructor
@NoArgsConstructor
public class DataTenantInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /** Master switch: when {@code false} the interceptor leaves every statement untouched. */
    private boolean tenantFlag;

    /**
     * Rewrites the bound SQL of a query before it is executed.
     *
     * <p>Does nothing when the interceptor is disabled, or when MyBatis-Plus reports that the
     * mapped statement ignores data permission. Otherwise the SQL is parsed, and the mapped
     * statement id is passed through to {@link #processSelect} as the "obj" argument.
     *
     * @throws SQLException declared by {@link InnerInterceptor#beforeQuery}
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        if (!tenantFlag) {
            return;
        }
        if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        String originalSql = mpBs.sql();
        // debug rather than info: this runs for every query and prints the full SQL text.
        log.debug("租户查询sql:{}", originalSql);
        mpBs.sql(this.parserSingle(originalSql, ms.getId()));
    }

    /**
     * Applies the tenant condition to a parsed SELECT.
     *
     * @param select parsed statement
     * @param index  statement index within the parsed SQL (unused)
     * @param sql    original SQL text (unused)
     * @param obj    the mapped statement id ({@code mapperClass.methodName}) supplied by
     *               {@link #beforeQuery}
     */
    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        if (!tenantFlag) {
            return;
        }
        this.processSelectBody(select.getSelectBody(), (String) obj);
    }

    /**
     * Dispatches a select body to {@link #setWhere}.
     *
     * <p>Set operations (UNION, INTERSECT, ...) are handled by recursing into each branch, so a
     * nested set operation is no longer blindly cast to {@link PlainSelect}. Any other body type
     * is left unchanged.
     *
     * @param selectBody   body to process; may be {@code null}
     * @param whereSegment mapped statement id
     */
    private void processSelectBody(SelectBody selectBody, String whereSegment) {
        if (selectBody instanceof PlainSelect) {
            this.setWhere((PlainSelect) selectBody, whereSegment);
        } else if (selectBody instanceof SetOperationList) {
            List<SelectBody> selectBodyList = ((SetOperationList) selectBody).getSelects();
            if (selectBodyList != null) {
                selectBodyList.forEach(s -> this.processSelectBody(s, whereSegment));
            }
        }
    }

    /**
     * Replaces the WHERE clause of {@code plainSelect} with the tenant-aware condition, if any.
     *
     * @param plainSelect  query to modify
     * @param whereSegment mapped statement id
     */
    private void setWhere(PlainSelect plainSelect, String whereSegment) {
        Expression sqlSegment = this.getSqlSegment(plainSelect, whereSegment);
        if (null != sqlSegment) {
            plainSelect.setWhere(sqlSegment);
        }
    }

    /**
     * Builds the WHERE expression to use for {@code plainSelect}.
     *
     * <p>Only queries whose FROM item is a plain table are filtered. The mapper method is found
     * by splitting {@code whereSegment} at its last dot and looking up the class reflectively.
     * If that method carries {@link DataTenantDataScope} and the logged-in user has a company
     * id, {@code <alias or table>.company_id = '<companyId>'} is ANDed onto the existing
     * condition. In every other case the original WHERE (possibly {@code null}) is returned
     * unchanged.
     *
     * @param plainSelect  query being rewritten
     * @param whereSegment mapped statement id, e.g. {@code com.example.FooMapper.selectList}
     * @return the expression to install as WHERE, or {@code null} if there is none
     */
    @SneakyThrows(Exception.class)
    public Expression getSqlSegment(PlainSelect plainSelect, String whereSegment) {
        Expression where = plainSelect.getWhere();
        if (!(plainSelect.getFromItem() instanceof Table)) {
            return where;
        }
        log.debug("开始进行权限过滤,where: {},mappedStatementId: {}", where, whereSegment);

        int lastDot = whereSegment.lastIndexOf(".");
        // Mapper class name and method name.
        String className = whereSegment.substring(0, lastDot);
        String methodName = whereSegment.substring(lastDot + 1);

        Table fromItem = (Table) plainSelect.getFromItem();
        // Qualify the column with the alias if present, else the table name, to avoid ambiguity.
        Alias fromItemAlias = fromItem.getAlias();
        String mainTableName = fromItemAlias == null ? fromItem.getName() : fromItemAlias.getName();

        // NOTE(review): statement ids derived by pagination plugins (e.g. count queries with a
        // suffixed id) will not match any mapper method and therefore are not filtered - confirm.
        Method[] methods = Class.forName(className).getMethods();
        log.debug("过滤权限类:{},方法：{},表：{},表别名：{}", className, methodName, fromItem.getName(), mainTableName);
        for (Method m : methods) {
            if (!Objects.equals(m.getName(), methodName)) {
                continue;
            }
            DataTenantDataScope annotation = m.getAnnotation(DataTenantDataScope.class);
            if (annotation == null) {
                return where;
            }
            LoginUser user = SecurityUtils.getLoginUser();
            if (StringUtils.isEmpty(user.getCompanyId())) {
                return where;
            }
            // <table>.company_id = '<companyId>'
            // NOTE(review): the id is embedded as a SQL string literal; assumes it comes from the
            // trusted session and contains no quote characters - confirm.
            EqualsTo companyEquals = new EqualsTo();
            companyEquals.setLeftExpression(new Column(mainTableName + ".company_id"));
            companyEquals.setRightExpression(new StringValue(user.getCompanyId()));
            if (where == null) {
                return companyEquals;
            }
            // Parenthesize the existing condition: without it, "a OR b" would become
            // "a OR b AND company_id = ..." and the "a" branch would bypass the tenant filter.
            return new AndExpression(new Parenthesis(where), companyEquals);
        }
        // No matching mapper method: leave the query as-is.
        return where;
    }

    /** @return whether tenant filtering is enabled */
    public boolean isTenantFlag() {
        return tenantFlag;
    }

    /** @param tenantFlag enables or disables tenant filtering */
    public void setTenantFlag(boolean tenantFlag) {
        this.tenantFlag = tenantFlag;
    }
}
